import pickle
import numpy as np
import pandas as pd
import tensorly as tl
import matplotlib.pyplot as plt
from wordcloud import WordCloud

# Select CuPy as the tensorly backend for everything that follows.
backend="cupy"
tl.set_backend(backend)
# NOTE(review): `device` is not referenced anywhere in the visible part of
# this file -- confirm it is still needed before relying on it.
device = 'cuda'

def generate_top_words(topic_word_dist, words, num_tops, top_n):
    '''Draw one word cloud per selected topic and save them as a single PNG.

    Parameters
    ----------
    topic_word_dist : 2-D array-like
        Indexed as ``topic_word_dist[topic, :]`` to obtain the per-word
        weights of one topic.
    words : iterable of str
        Vocabulary, paired positionally with the entries of each topic row.
    num_tops : sequence of int
        Row indices of the topics to draw; one subplot is created per entry.
    top_n : int
        Maximum number of words per cloud (passed to ``WordCloud`` as
        ``max_words``).

    Raises
    ------
    ValueError
        If ``num_tops`` is empty.

    The image is written to ``results/wordcloud_covid<str(num_tops)>.png``.
    '''
    topic_ids = list(num_tops)
    if not topic_ids:
        raise ValueError("num_tops must contain at least one topic index")

    cloud = WordCloud(background_color='white',
                      width=2600,
                      height=1800,
                      max_words=top_n,
                      colormap='tab10')

    # One column per requested topic (20 inches wide each, so three topics
    # still give the previous 60x20 canvas). squeeze=False keeps `axes` a
    # 2-D array even when only a single topic is requested.
    n_topics = len(topic_ids)
    fig, axes = plt.subplots(1, n_topics, figsize=(20 * n_topics, 20),
                             sharex=True, sharey=True, squeeze=False)
    try:
        for topic, ax in zip(topic_ids, axes.flatten()):
            cloud.generate_from_frequencies(
                dict(zip(words, topic_word_dist[topic, :])))
            ax.imshow(cloud)
            # Label with the real topic index, not the subplot position, so
            # the picture can be matched against other per-topic outputs.
            ax.set_title('Topic ' + str(topic), fontdict=dict(size=16))
            ax.axis('off')
            ax.margins(x=0, y=0)

        fig.subplots_adjust(wspace=0, hspace=0)
        fig.tight_layout()
        # The file name embeds str(num_tops) exactly as passed by the caller.
        fig.savefig("results/wordcloud_covid" + str(num_tops) + ".png")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return

# Number of highest-weighted vocabulary entries kept per topic.
K = 20

# NOTE: unpickling can execute arbitrary code -- only load model files that
# come from a trusted source.
# The context manager closes the file handle as soon as the model is loaded.
with open('data/covid_experiment/num_tops_5_alpha0_0.0001_learning_rate_1e-05_theta_5.005_orthogonality_1000_initialize_first_docs_True_n_eigenvec_20/tlda.obj', 'rb') as model_file:
    tlda = pickle.load(model_file)
# Presumably `unwhitened_factors` is a CuPy array and `.get()` copies it to
# host memory -- TODO confirm. Transposed so that each row iterated below is
# one factor.
factors_tlda = tlda.unwhitened_factors.get().T

vocab_pandas = pd.read_csv('data/covid_experiment/vocab.csv')['words']
vocab = np.array(vocab_pandas)

# Map each factor index to the vocabulary entries with its K largest weights.
top_words_tlda = {}
for i, f in enumerate(factors_tlda):
    # argpartition yields the indices of the K largest entries, but in no
    # particular order, so the words within a topic are not ranked.
    top_tlda = np.argpartition(f,-K)[-K:]
    top_words_tlda[i] = list(vocab[top_tlda])

print(top_words_tlda)
df = pd.DataFrame(top_words_tlda)
df.to_csv('results/covid_topics_tlda.csv')

# NOTE(review): `num_tops` is not used by the call below, which passes an
# explicit list of topic indices instead -- confirm whether it is still needed.
num_tops = 5
top_n    = 30
generate_top_words(factors_tlda,vocab_pandas,[0, 1, 3],top_n)